{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Importing Brevitas networks into FINN with the QONNX interchange format\n",
    "\n",
    "**Note: Previously it was possible to directly export the FINN-ONNX interchange format from Brevitas to pass to the FINN compiler. This support is deprecated and FINN uses the export to the QONNX format as a front end, internally FINN uses still the FINN-ONNX format.**\n",
    "\n",
    "In this notebook we'll go through an example of how to import a Brevitas-trained QNN into FINN. The steps will be as follows:\n",
    "\n",
    "1. Load up the trained PyTorch model\n",
    "2. Call Brevitas QONNX export and visualize with Netron\n",
    "3. Import into FINN and converting QONNX to FINN-ONNX\n",
    "\n",
    "We'll use the following utility functions to print the source code for function calls (`showSrc()`) and to visualize a network using netron (`showInNetron()`) in the Jupyter notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "from finn.util.visualization import showSrc, showInNetron"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load up the trained PyTorch model\n",
    "\n",
    "The FINN Docker image comes with several [example Brevitas networks](https://github.com/Xilinx/brevitas/tree/master/src/brevitas_examples/bnn_pynq), and we'll use the LFC-w1a1 model as the example network here. This is a binarized fully connected network trained on the MNIST dataset. Let's start by looking at what the PyTorch network definition looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brevitas_examples import bnn_pynq\n",
    "showSrc(bnn_pynq.models.FC)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the network topology is constructed using a few helper functions that generate the quantized linear layers and quantized activations. The bitwidth of the layers is actually parametrized in the constructor, so let's instantiate a 1-bit weights and activations version of this network. We also have pretrained weights for this network, which we will load into the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.util.test import get_test_model\n",
    "lfc = get_test_model(netname = \"LFC\", wbits = 1, abits = 1, pretrained = True)\n",
    "lfc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now instantiated our trained PyTorch network. Let's try to run an example MNIST image through the network using PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from pkgutil import get_data\n",
    "import onnx\n",
    "import onnx.numpy_helper as nph\n",
    "raw_i = get_data(\"qonnx.data\", \"onnx/mnist-conv/test_data_set_0/input_0.pb\")\n",
    "input_tensor = onnx.load_tensor_from_string(raw_i)\n",
    "input_tensor_npy = nph.to_array(input_tensor)\n",
    "input_tensor_pyt = torch.from_numpy(input_tensor_npy).float()\n",
    "imgplot = plt.imshow(input_tensor_npy.reshape(28,28), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import softmax\n",
    "# do forward pass in PyTorch/Brevitas\n",
    "produced = lfc.forward(input_tensor_pyt).detach()\n",
    "probabilities = softmax(produced, dim=-1).flatten()\n",
    "probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "objects = [str(x) for x in range(10)]\n",
    "y_pos = np.arange(len(objects))\n",
    "plt.bar(y_pos, probabilities, align='center', alpha=0.5)\n",
    "plt.xticks(y_pos, objects)\n",
    "plt.ylabel('Predicted Probability')\n",
    "plt.title('LFC-w1a1 Predictions for Image')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Call Brevitas QONNX export and visualize with Netron\n",
    "\n",
    "Brevitas comes with built-in QONNX export functionality. This is similar to the regular ONNX export capabilities of PyTorch, with a few differences:\n",
    "\n",
    "1. Weight and activation quantization is represented as a 'fake-quantization' with Quant and BipolarQuant nodes.\n",
    "2. Truncation operations as required by average pooling are represented with a Trunc node.\n",
    "\n",
    "One can read more about how QONNX works and why it was developed here: https://xilinx.github.io/finn//2021/11/03/qonnx-and-finn.html\n",
    "\n",
    "Additionally QONNX comes with a set of tools for working with the format. These are maintained together with the Fast Machinelearning collaboration as an open-source projet here: https://github.com/fastmachinelearning/qonnx\n",
    "\n",
    "It's actually quite straightforward to export QONNX from our Brevitas model as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from brevitas.export import export_qonnx\n",
    "export_onnx_path = \"/tmp/LFCW1A1_qonnx.onnx\"\n",
    "input_shape = (1, 1, 28, 28)\n",
    "export_qonnx(lfc, torch.randn(input_shape), export_onnx_path);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's examine what the exported ONNX model looks like. For this, we will use the Netron visualizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "showInNetron(export_onnx_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When running this notebook in the FINN Docker container, you should be able to see an interactive visualization of the imported network above, and click on individual nodes to inspect their parameters. If you look at any of the MatMul nodes, you should be able to see that the weights are all {-1, +1} values."
   ]
  },
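  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what 'fake-quantization' means for the {-1, +1} weights seen above, the following minimal sketch mimics a bipolar quantizer in plain NumPy. The function name `bipolar_fake_quant` and the sample values are hypothetical, and the convention that zero maps to the positive level is an assumption rather than a statement about the exporter's exact behavior. The point is that the values stay in floating point but are restricted to two levels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def bipolar_fake_quant(x, scale=1.0):\n",
    "    # hypothetical helper: map every value to -scale or +scale\n",
    "    # assumption: zero is mapped to +scale, and scale is positive\n",
    "    x = np.asarray(x, dtype=np.float32)\n",
    "    levels = np.where(x >= 0.0, 1.0, -1.0).astype(np.float32)\n",
    "    return levels * np.float32(scale)\n",
    "\n",
    "\n",
    "# made-up floating-point weights, purely for illustration\n",
    "float_weights = np.array([[0.7, -0.2], [0.0, -1.3]], dtype=np.float32)\n",
    "quantized = bipolar_fake_quant(float_weights)\n",
    "\n",
    "# the result is still a float32 array, but it only holds the two bipolar levels\n",
    "print(quantized)\n",
    "print(quantized.dtype)"
   ]
  },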
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Import into FINN and converting QONNX to FINN-ONNX\n",
    "\n",
    "We will first run a cleanup transformation on the exported QONNX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qonnx.util.cleanup import cleanup\n",
    "\n",
    "export_onnx_path_cleaned = \"/tmp/LFCW1A1-qonnx-clean.onnx\"\n",
    "cleanup(export_onnx_path, out_file=export_onnx_path_cleaned)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "showInNetron(export_onnx_path_cleaned)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now import this QONNX model into FINN using the ModelWrapper. Here we can immediatley execute the model to verify correctness."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qonnx.core.modelwrapper import ModelWrapper\n",
    "import qonnx.core.onnx_exec as oxe\n",
    "model = ModelWrapper(export_onnx_path_cleaned)\n",
    "input_dict = {\"global_in\": nph.to_array(input_tensor)}\n",
    "output_dict = oxe.execute_onnx(model, input_dict)\n",
    "produced_qonnx = output_dict[list(output_dict.keys())[0]]\n",
    "\n",
    "produced_qonnx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.isclose(produced, produced_qonnx).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the `QONNXtoFINN` transformation we can convert the model to the FINN internal FINN-ONNX representation. Notably all Quant and BipolarQuant nodes will have disappeared and are converted into MultiThreshold nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN\n",
    "model = ModelWrapper(export_onnx_path_cleaned)\n",
    "\n",
    "model = model.transform(ConvertQONNXtoFINN())\n",
    "\n",
    "export_onnx_path_converted = \"/tmp/LFCW1A1-qonnx-converted.onnx\"\n",
    "model.save(export_onnx_path_converted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "showInNetron(export_onnx_path_converted)"
   ]
  },
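  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for why a bipolar quantizer can be expressed as a thresholding operation, here is a minimal NumPy sketch. It is not the FINN implementation: the helper names `bipolar_quant` and `multithreshold` are hypothetical, the thresholding is simplified to a single shared threshold list rather than per-channel thresholds, and the `>=` comparison together with the output scale and bias are assumptions chosen for illustration. The sketch counts how many thresholds each input meets or exceeds, then scales and shifts that count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def bipolar_quant(x):\n",
    "    # hypothetical reference: +1 for non-negative inputs, -1 otherwise\n",
    "    return np.where(x >= 0.0, 1.0, -1.0)\n",
    "\n",
    "\n",
    "def multithreshold(x, thresholds, out_scale=1.0, out_bias=0.0):\n",
    "    # simplified sketch: count the thresholds each element meets or exceeds,\n",
    "    # then apply an affine mapping to the count\n",
    "    counts = (x[..., None] >= thresholds).sum(axis=-1)\n",
    "    return out_scale * counts + out_bias\n",
    "\n",
    "\n",
    "# made-up inputs, purely for illustration\n",
    "x = np.array([-1.5, -0.1, 0.0, 0.3, 2.0])\n",
    "via_quant = bipolar_quant(x)\n",
    "# one threshold at zero gives counts in {0, 1}; scale 2 and bias -1 map them to {-1, +1}\n",
    "via_thresholds = multithreshold(x, np.array([0.0]), out_scale=2.0, out_bias=-1.0)\n",
    "\n",
    "print(via_quant)\n",
    "print(via_thresholds)\n",
    "print(np.array_equal(via_quant, via_thresholds))"
   ]
  },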
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And once again we can execute the model with the FINN/QONNX execution engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ModelWrapper(export_onnx_path_converted)\n",
    "input_dict = {\"global_in\": nph.to_array(input_tensor)}\n",
    "output_dict = oxe.execute_onnx(model, input_dict)\n",
    "produced_finn = output_dict[list(output_dict.keys())[0]]\n",
    "\n",
    "produced_finn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.isclose(produced_qonnx, produced_finn).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have succesfully verified that the transformed and cleaned-up FINN graph still produces the same output, and can now use this model for further processing in FINN."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
